import argparse
import os

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from sklearn.cluster import KMeans
from sklearn.manifold import TSNE
from tensorflow import keras as K


# command-line configuration (all flags are integers)
parser = argparse.ArgumentParser()
for flag, default in [("--n_centre", 10),
                      ("--n_per", 512),
                      ("--n_class", 10),
                      ("--epoch", 10),
                      ("--batch_size", 64)]:
    parser.add_argument(flag, type=int, default=default)
args = parser.parse_args()


def euclidean(A, B=None, sqrt=False):
    """Pairwise (squared) Euclidean distances between rows of A and B.

    A: [n, d] tensor
    B: [m, d] tensor, or None/A to compute distances within A
    sqrt: if True, return true distances instead of squared distances
    Returns: [n, m] tensor D with D[i, j] = ||A_i - B_j||^2 (or its sqrt)
    """
    # ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2.  Row norms come from a cheap
    # O(n d) reduce_sum instead of materialising matmul(A, A^T) only to take
    # its diagonal, as the previous diag_part(matmul(...)) version did.
    aTa = tf.reduce_sum(tf.square(A), axis=1)
    if (B is None) or (B is A):
        B = A
        bTb = aTa
    else:
        bTb = tf.reduce_sum(tf.square(B), axis=1)
    D = aTa[:, None] - 2.0 * tf.matmul(A, tf.transpose(B)) + bTb[None, :]
    # numerical error can push tiny distances slightly below zero
    D = tf.maximum(D, 0.0)
    if sqrt:
        # nudge exact zeros before sqrt so the gradient stays finite there,
        # then mask them back to zero; cast to D.dtype (the old hard-coded
        # float32 cast broke on float64 inputs)
        mask = tf.cast(tf.equal(D, 0.0), D.dtype)
        D = tf.math.sqrt(D + mask * 1e-16) * (1.0 - mask)
    return D


# data: MNIST digits, flattened to 784-d rows
mnist = K.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# scale pixels to [0, 1] and cast to float32: `uint8 / 255.0` yields float64,
# which would not match the float32 RBF centres in tf.matmul
x_train = (x_train / 255.0).astype("float32")   # [n, 28, 28]
x_test = (x_test / 255.0).astype("float32")
x_train = tf.reshape(x_train, [-1, 784])
x_test = tf.reshape(x_test, [-1, 784])
print("data:", type(x_train), x_train.shape, y_test.shape)

train_ds = tf.data.Dataset.from_tensor_slices(
    (x_train, y_train)).shuffle(10000).batch(args.batch_size)
test_ds = tf.data.Dataset.from_tensor_slices(
    (x_test, y_test)).batch(args.batch_size)


class RBF_Cell(K.Model):
    def __init__(self, centre, var):
        """Single RBF unit with a Gaussian kernel.

        centre: [d] vector, kernel centre (same feature dim as input x)
        var: scalar, variance of the Gaussian
        """
        super(RBF_Cell, self).__init__()
        self.centre = tf.cast(centre[None, :], "float32")  # [1, d]
        # exp(coef * ||x - c||^2) with coef = -1 / (2 var)
        self.coef = tf.cast(- 0.5 / var, "float32")

    def call(self, x):
        """[n, d] -> [n, 1]: exp(-||x - centre||^2 / (2 var)) per row.

        (Docstring previously claimed an [n, 1] input; the centre is [1, d]
        so the input must be [n, d].)
        """
        # cast defensively: float64 inputs (e.g. numpy data / 255.0) would
        # otherwise fail to matmul against the float32 centre
        x = tf.cast(x, "float32")
        return tf.math.exp(self.coef * euclidean(x, self.centre))


class RBF_Net(K.Model):
    def __init__(self, centres, variances):
        """X -> [rbf_i(X; C_i)] -> Y

        centres: [n_centre, d] array of kernel centres
        variances: [n_centre] array of kernel variances
        """
        super(RBF_Net, self).__init__()
        # one Gaussian unit per (centre, variance) pair
        self.rbf_list = []
        for centre, variance in zip(centres, variances):
            self.rbf_list.append(RBF_Cell(centre, variance))
        # linear read-out on top of the concatenated kernel activations
        self.fc = K.layers.Dense(args.n_class, input_shape=[centres.shape[0]])

    def call(self, x):
        """[n, d] -> [n, #centres] -> [n, n_class]; returns (feat, logit)."""
        activations = [cell(x) for cell in self.rbf_list]
        feat = tf.concat(activations, axis=1)
        return feat, self.fc(feat)


# clustering: one RBF centre per class (the class mean), with the variance
# taken as the mean squared distance of that class's samples to its mean.
# The result depends only on the training data, so it is cached on disk;
# previously the cached .npy files were loaded unconditionally and the
# script crashed when they were missing.
if os.path.exists("centres.npy") and os.path.exists("variances.npy"):
    centres = np.load("centres.npy")
    variances = np.load("variances.npy")
else:
    centres, variances = [], []
    for c in range(args.n_class):
        Xc = tf.boolean_mask(x_train, y_train == c)
        EX = tf.reduce_mean(Xc, 0, keepdims=True)
        centres.append(EX)
        # Var(X) = E||X - EX||^2
        variances.append(tf.reduce_mean(euclidean(Xc, EX)))
    centres, variances = np.vstack(centres), np.asarray(variances)
    np.save("centres.npy", centres)
    np.save("variances.npy", variances)
print("centres:", centres.shape, ", variances:", variances.shape)


# model, loss, optimizer
model = RBF_Net(centres, variances)
# labels are integer class ids (not one-hot), hence the `Sparse` variant;
# the network emits raw logits, hence from_logits=True
criterion = K.losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = K.optimizers.Adam()

# running metrics, accumulated over an epoch
train_loss = K.metrics.Mean(name='train_loss')
test_loss = K.metrics.Mean(name='test_loss')
train_accuracy = K.metrics.SparseCategoricalAccuracy(name='train_accuracy')
test_accuracy = K.metrics.SparseCategoricalAccuracy(name='test_accuracy')


#@tf.function
def train_step(images, labels):
    """Run one optimizer step on a batch and update the train metrics.

    Returns the batch loss tensor.
    """
    with tf.GradientTape() as tape:
        _, logits = model(images)
        batch_loss = criterion(labels, logits)
    grads = tape.gradient(batch_loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))

    train_loss(batch_loss)
    train_accuracy(labels, logits)
    return batch_loss


#@tf.function
def test_step(images, labels):
    """Evaluate one batch and update the test metrics.

    Returns the number of correctly classified samples as a scalar tensor.
    """
    _, logits = model(images)
    test_loss(criterion(labels, logits))
    test_accuracy(labels, logits)

    predicted = tf.argmax(logits, axis=1)
    hits = tf.equal(predicted, tf.cast(labels, "int64"))
    return tf.reduce_sum(tf.cast(hits, "float32"))


# training loop: loss_list gets one entry per BATCH, acc_list one per EPOCH
loss_list, acc_list = [], []
for epoch in range(args.epoch):
    # reset the running metrics at the start of every new epoch
    train_loss.reset_states()
    train_accuracy.reset_states()
    test_loss.reset_states()
    test_accuracy.reset_states()

    for images, labels in train_ds:
        # images = tf.image.resize(images, [224, 224])
        # images = tf.tile(images, tf.constant([1, 1, 1, 3]))
        l = train_step(images, labels)
        loss_list.append(l.numpy())  # per-batch training loss

    # evaluate on the whole test set, counting correct predictions
    n_corr = 0
    for images, labels in test_ds:
        # images = tf.image.resize(images, [224, 224])
        # images = tf.tile(images, tf.constant([1, 1, 1, 3]))
        _n_corr = test_step(images, labels)
        n_corr += _n_corr.numpy()
    acc_list.append(n_corr / y_test.shape[0])  # per-epoch test accuracy

    template = 'Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}'
    print(template.format(epoch+1,
                          train_loss.result(),
                          train_accuracy.result()*100,
                          test_loss.result(),
                          test_accuracy.result()*100))


def _save_curve(values, title, path):
    """Plot `values` against their index and save the figure to `path`."""
    fig = plt.figure()
    plt.title(title)
    plt.plot(np.arange(len(values)), values)
    # plt.show()
    fig.savefig(path)


# per-batch training loss and per-epoch test accuracy
_save_curve(loss_list, "loss", "loss.png")
_save_curve(acc_list, "accuracy", "accuracy.png")


# T-SNE visualisation of the RBF feature layer on a few test batches
fea_list = []
for batch_idx, (images, labels) in enumerate(test_ds):
    feat, _ = model(images)
    fea_list.append(feat.numpy())
    if batch_idx > 5:  # keep only the first 7 batches
        break

F = np.vstack(fea_list)
F = TSNE(n_components=2, init="pca", random_state=0).fit_transform(F)
# rescale the embedding into the unit square for plotting
lo, hi = F.min(0), F.max(0)
F = (F - lo) / (hi - lo)
fig = plt.figure()
plt.title("T-SNE")
# test_ds is not shuffled, so row idx of F lines up with y_test[idx]
for idx in range(F.shape[0]):
    plt.text(F[idx, 0], F[idx, 1], str(y_test[idx]),
             color=plt.cm.Set1(y_test[idx] / 10.))
fig.savefig("tsne.png")
